{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Beginner Tutorial: Adding an ONNXToMindSpore Operator Mapping\n",
    "\n",
    "`Linux` `Ascend` `GPU` `CPU` `Model Migration` `Beginner`\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/mindinsight/blob/master/ecosystem_tools/mindconverter/tutorial/add_onnx2mindspore_operator_mapper_base_tutorial.ipynb)"
   ]
  },
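  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "MindConverter performs script conversion from an ONNX model and generates a MindSpore script together with a weight file. It therefore relies on mapping files from ONNX operators to MindSpore operators to ensure that each operator is converted correctly. This tutorial uses a simple operator mapping as an example to show how to add an operator mapping file."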
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment Preparation\n",
    "\n",
    "This example requires the following third-party Python packages:\n",
    "\n",
    "```bash\n",
    "pip install mindspore==1.6.0\n",
    "pip install mindconverter==1.6.0\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structure of the Operator Mapper Script (base.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import abc\n",
    "\n",
    "\n",
    "class Mapper(metaclass=abc.ABCMeta):\n",
    "    \"\"\"Mapper between third-party-operation and MindSpore.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    @abc.abstractmethod\n",
    "    def _operation_name_in_ms(**kwargs):\n",
    "        \"\"\"Corresponding operation name in MindSpore.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    @abc.abstractmethod\n",
    "    def _convert_params(**kwargs):\n",
    "        \"\"\"Convert third-party-operation's attributes or weights into MindSpore operation's attributes.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    @abc.abstractmethod\n",
    "    def _convert_trained_weights(**kwargs):\n",
    "        \"\"\"Convert third-party-operation's trainable weights into MindSpore operation's.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    @abc.abstractmethod\n",
    "    def _generate_snippet_template(**kwargs):\n",
    "        \"\"\"Generate code template according to node info.\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Adding a Custom Operator Mapper Script\n",
    "\n",
    "The `onnx::Gemm` operator is used as the example.\n",
    "\n",
    "Consult the [ONNX operator API documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md) and the [MindSpore operator API documentation](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)\n",
    "to find the MindSpore operator whose behavior is the same as, or close to, that of the ONNX operator `onnx::Gemm`. Here that operator is `mindspore.nn.Dense`.\n",
    "\n",
    "|Operator|`onnx::Gemm`|`mindspore.nn.Dense`|\n",
    "|:----:|:----|:----|\n",
    "|Computation|`Y = alpha*A'*B'+beta*C`|`output = activation(inputs*kernel+bias)`|\n",
    "|Parameters|`alpha`: optional<br>`beta`: optional<br>`transA`: optional<br>`transB`: optional|`in_channels`: required<br>`out_channels`: required<br>`weight_init`: optional<br>`bias_init`: optional<br>`has_bias`: optional<br>`activation`: optional|\n",
    "|Inputs|`A`: required<br>`B`: required<br>`C`: optional|`input`: required|\n",
    "|Outputs|`Y`|`output`|\n",
    "\n",
    "<br>\n",
    "The ONNX operator is mapped to the MindSpore operator based on the parameters (Attributes/Parameters) and inputs (Inputs) of the two operators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Import the Mapper base class\n",
    "from mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper\n",
    "\n",
    "\n",
    "class DenseMapper(ONNXToMindSporeMapper):\n",
    "    \"\"\"Dense mapper.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _operation_name_in_ms(*args, **kwargs):\n",
    "        return \"nn.Dense\"  # Name of the corresponding operator in MindSpore\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_params(**kwargs):\n",
    "        \"\"\"\n",
    "        Convert parameters. The parameters returned by this method define the operator\n",
    "        in the generated MindSpore script in the form\n",
    "        `OP(dict_key_0=dict_value_0, dict_key_1=dict_value_1, ...)`,\n",
    "        so each dict_key_x must match a parameter name of the MindSpore operator.\n",
    "\n",
    "        Args:\n",
    "            kwargs: Data for converting.\n",
    "            Struct is `{\n",
    "                        'weights': [NodeWeight(), NodeWeight(), ...],\n",
    "                        'params': {\n",
    "                                    'input_shape': input_shape,\n",
    "                                    'output_shape': output_shape,\n",
    "                                    'onnx_attribute_name_0': onnx_attribute_val_0,\n",
    "                                    ...\n",
    "                                  }\n",
    "                        }`\n",
    "        \"\"\"\n",
    "\n",
    "        weights = kwargs['weights']  # Static tensors among the ONNX operator's Inputs\n",
    "        # Fetch the tensor value at the given index; the index follows the order of the ONNX operator's Inputs.\n",
    "        weight = DenseMapper._find_val_by_index(0, weights)\n",
    "        bias = DenseMapper._find_val_by_index(1, weights)\n",
    "        has_bias = isinstance(bias, np.ndarray)\n",
    "        in_channels, out_channels = weight.shape\n",
    "        return {\n",
    "            'in_channels': in_channels,\n",
    "            'out_channels': out_channels,\n",
    "            'has_bias': has_bias\n",
    "        }\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_trained_weights(**kwargs):\n",
    "        \"\"\"\n",
    "        Convert weights. The weights returned by this method are saved in the generated\n",
    "        CheckPoint (.ckpt) file, so that the generated MindSpore script can load that\n",
    "        weight file directly for retraining or inference.\n",
    "        See the advanced tutorial for details.\n",
    "        \"\"\"\n",
    "\n",
    "        weights = kwargs['weights']\n",
    "        weight = DenseMapper._find_val_by_index(0, weights)\n",
    "        bias = DenseMapper._find_val_by_index(1, weights)\n",
    "        return {\n",
    "            'weight': {'data': weight},\n",
    "            'bias': {'data': bias}\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Name the mapper script `dense_mapper.py`; the file name must correspond to the class name (`DenseMapper`).<br>\n",
    "Place it in the `mindconverter/graph_based_converter/mapper/impl/nn` directory; choose the directory according to the layer (`nn`/`ops`) that the corresponding MindSpore operator belongs to.<br>\n",
    "Finally, edit `mindconverter/graph_based_converter/mapper/onnx_to_ms.json`\n",
    "and add `\"onnx::Gemm\": \"mindconverter.graph_based_converter.mapper.impl.nn.dense_mapper.DenseMapper\"` so that the ONNX operator is associated with its mapper script."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Verifying the Custom Operator Mapper Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper\n",
    "from mindconverter.graph_based_converter.common.code_fragment import Fragment\n",
    "from mindconverter.graph_based_converter.third_party_graph.base import NodeWeight\n",
    "\n",
    "def test_mapper(onnx_info):\n",
    "    \"\"\"\n",
    "    Test mapper.\n",
    "\n",
    "    Args:\n",
    "        onnx_info (dict): Onnx operator_info. Struct is\n",
    "                                   {\n",
    "                                    'op_name': op_name,\n",
    "                                    'attributes': dict(),\n",
    "                                    'weights': [NodeWeight(), ...]\n",
    "                                   }\n",
    "    \"\"\"\n",
    "\n",
    "    template, exchange_msg, outputs_lists, outputs_mapping = \\\n",
    "        ONNXToMindSporeMapper.convert(onnx_info['op_name'],\n",
    "                                      onnx_info['attributes'],\n",
    "                                      onnx_info['weights'])\n",
    "\n",
    "    exchange_msg['var_0']['variable_name'] = 'self_defined_operator'\n",
    "    exchange_msg['var_0']['inputs'] = ['x']\n",
    "\n",
    "    fragment = Fragment(data_entity=exchange_msg, code_template=template, outputs=outputs_lists,\n",
    "                        outputs_mapping=outputs_mapping)\n",
    "\n",
    "    code = fragment()\n",
    "    init_code = code[0]\n",
    "    construct_code = code[1]\n",
    "    print('-'*30, 'init_code', '-'*30)\n",
    "    print('\\n'.join(init_code))\n",
    "    print('-'*30, 'construct_code', '-'*30)\n",
    "    print('\\n'.join(construct_code))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------ init_code ------------------------------\n",
      "self.self_defined_operator = nn.Dense(in_channels=3, out_channels=10, has_bias=True)\n",
      "------------------------------ construct_code ------------------------------\n",
      "opt_self_defined_operator = self.self_defined_operator(x)\n"
     ]
    }
   ],
   "source": [
    "onnx_operator_info = {'op_name': 'onnx::Gemm',\n",
    "                      'attributes': {'alpha': 1.0,\n",
    "                                     'beta': 1.0,\n",
    "                                     'transA': 0,\n",
    "                                     'transB': 0},\n",
    "                      'weights': [NodeWeight(weight_name='weight',\n",
    "                                             weight_location=1,\n",
    "                                             weight_value=np.ones((10, 3),\n",
    "                                                                  dtype=int)),\n",
    "                                  NodeWeight(weight_name='bias',\n",
    "                                             weight_location=2,\n",
    "                                             weight_value=np.ones((10, 3),\n",
    "                                                                  dtype=int))]}\n",
    "test_mapper(onnx_operator_info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}